{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 一、配置环境变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from dotenv import dotenv_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ZHIPUAI_API_KEY\n",
      "DASHSCOPE_API_KEY\n"
     ]
    }
   ],
   "source": [
    "env_variables = dotenv_values('.env')\n",
    "for var in env_variables:\n",
    "    print(var)\n",
    "    os.environ[var] = env_variables[var]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'sk-f7e5ff2257cd4cff95eaedffb9ddb35e'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.getenv(\"DASHSCOPE_API_KEY\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 二、配置LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!/usr/bin/env python\n",
    "# -*- coding: utf-8 -*-\n",
    "\n",
    "import os\n",
    "from typing import Dict, List, Optional, Tuple, Union\n",
    "\n",
    "from zhipuai import ZhipuAI\n",
    "import dashscope\n",
    "import random\n",
    "from http import HTTPStatus\n",
    "import dashscope\n",
    "from dashscope import Generation  # 建议dashscope SDK 的版本 >= 1.14.0\n",
    "\n",
    "PROMPT_TEMPLATE = dict(\n",
    "    RAG_PROMPT_TEMPALTE=\"\"\"使用以上下文来回答用户的问题。如果你不知道答案，就说你不知道。总是使用中文回答。\n",
    "        问题: {question}\n",
    "        可参考的上下文：\n",
    "        ···\n",
    "        {context}\n",
    "        ···\n",
    "        如果给定的上下文无法让你做出回答，请回答数据库中没有这个内容，你不知道。\n",
    "        有用的回答:\"\"\",\n",
    "    InternLM_PROMPT_TEMPALTE=\"\"\"先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案，就说你不知道。总是使用中文回答。\n",
    "        问题: {question}\n",
    "        可参考的上下文：\n",
    "        ···\n",
    "        {context}\n",
    "        ···\n",
    "        如果给定的上下文无法让你做出回答，请回答数据库中没有这个内容，你不知道。\n",
    "        有用的回答:\"\"\"\n",
    ")\n",
    "\n",
    "\n",
    "class BaseModel:\n",
    "    def __init__(self, path: str = '') -> None:\n",
    "        self.path = path\n",
    "\n",
    "    def chat(self, prompt: str, history: List[dict], content: str) -> str:\n",
    "        pass\n",
    "\n",
    "    def load_model(self):\n",
    "        pass\n",
    "\n",
    "class GLM4Chat(BaseModel):\n",
    "    def __init__(self, path: str = '', model: str = \"glm-4\") -> None:\n",
    "        super().__init__(path)\n",
    "        self.model = model\n",
    "\n",
    "    def chat(self, prompt: str, history: List[dict], content: str) -> str:\n",
    "        client = ZhipuAI(api_key=os.getenv(\"ZHIPUAI_API_KEY\"))  # 填写您自己的APIKey\n",
    "        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})\n",
    "        response = client.chat.completions.create(\n",
    "            model=\"glm-4\",  # 填写需要调用的模型名称\n",
    "            messages=history\n",
    "        )\n",
    "        return response.choices[0].message\n",
    "\n",
    "class QwenChat(BaseModel):\n",
    "    def __init__(self, path: str = '', model: str = \"qwen-turbo\") -> None:\n",
    "        super().__init__(path)\n",
    "        dashscope.api_key = os.getenv(\"DASHSCOPE_API_KEY\")\n",
    "        self.model = model\n",
    "\n",
    "    def chat(self, prompt: str, history: List[dict], content: str) -> str:\n",
    "        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})\n",
    "        response = Generation.call(model=\"qwen-turbo\",\n",
    "                                messages=history,\n",
    "                                # 设置随机数种子seed，如果没有设置，则随机数种子默认为1234\n",
    "                                seed=random.randint(1, 10000),\n",
    "                                # 将输出设置为\"message\"格式\n",
    "                                result_format='message')\n",
    "        if response.status_code == HTTPStatus.OK:\n",
    "            return response.output.choices[0].message[\"content\"]\n",
    "        else:\n",
    "            print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (\n",
    "                response.request_id, response.status_code,\n",
    "                response.code, response.message\n",
    "            ))\n",
    "\n",
    "\n",
    "class OpenAIChat(BaseModel):\n",
    "    def __init__(self, path: str = '', model: str = \"gpt-3.5-turbo-1106\") -> None:\n",
    "        super().__init__(path)\n",
    "        self.model = model\n",
    "\n",
    "    def chat(self, prompt: str, history: List[dict], content: str) -> str:\n",
    "        from openai import OpenAI\n",
    "        client = OpenAI()   \n",
    "        history.append({'role': 'user', 'content': PROMPT_TEMPLATE['RAG_PROMPT_TEMPALTE'].format(question=prompt, context=content)})\n",
    "        response = client.chat.completions.create(\n",
    "            model=self.model,\n",
    "            messages=history,\n",
    "            max_tokens=150,\n",
    "            temperature=0.1\n",
    "        )\n",
    "        return response.choices[0].message.content\n",
    "\n",
    "class InternLMChat(BaseModel):\n",
    "    def __init__(self, path: str = '') -> None:\n",
    "        super().__init__(path)\n",
    "        self.load_model()\n",
    "\n",
    "    def chat(self, prompt: str, history: List = [], content: str='') -> str:\n",
    "        prompt = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPALTE'].format(question=prompt, context=content)\n",
    "        response, history = self.model.chat(self.tokenizer, prompt, history)\n",
    "        return response\n",
    "\n",
    "\n",
    "    def load_model(self):\n",
    "        import torch\n",
    "        from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)\n",
    "        self.model = AutoModelForCausalLM.from_pretrained(self.path, torch_dtype=torch.float16, trust_remote_code=True).cuda()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'你好！有什么我可以帮助你的吗？'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "qwenchat = QwenChat()\n",
    "qwenchat.chat(history=[], prompt=\"hello\", content=\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 三、设置知识库FaissVectorStore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "from typing import Dict, List, Optional, Tuple, Union\n",
    "import json\n",
    "\n",
    "import faiss\n",
    "\n",
    "from RAG.Embeddings import BaseEmbeddings, OpenAIEmbedding, JinaEmbedding, ZhipuEmbedding\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "from RAG.VectorBase import VectorStore\n",
    "\n",
    "\n",
    "class FaissVectorStore(VectorStore):\n",
    "\n",
    "    def __init__(self, document: List[str] = ['']) -> None:\n",
    "        super().__init__(document)\n",
    "        self.document = document\n",
    "\n",
    "    def get_vector(self, EmbeddingModel: BaseEmbeddings) -> List[List[float]]:\n",
    "        self.vectors = []\n",
    "        for doc in tqdm(self.document, desc=\"Calculating embeddings\"):\n",
    "            self.vectors.append(EmbeddingModel.get_embedding(doc))\n",
    "        return self.vectors\n",
    "\n",
    "    def persist(self, path: str = 'storage'):\n",
    "        if not os.path.exists(path):\n",
    "            os.makedirs(path)\n",
    "        with open(f\"{path}/doecment.json\", 'w', encoding='utf-8') as f:\n",
    "            json.dump(self.document, f, ensure_ascii=False)\n",
    "        if self.vectors:\n",
    "            with open(f\"{path}/vectors.json\", 'w', encoding='utf-8') as f:\n",
    "                json.dump(self.vectors, f)\n",
    "\n",
    "    def load_vector(self, path: str = 'storage'):\n",
    "        with open(f\"{path}/vectors.json\", 'r', encoding='utf-8') as f:\n",
    "            self.vectors = json.load(f)\n",
    "        with open(f\"{path}/doecment.json\", 'r', encoding='utf-8') as f:\n",
    "            self.document = json.load(f)\n",
    "\n",
    "    def get_similarity(self, vector1: List[float], vector2: List[float]) -> float:\n",
    "        return BaseEmbeddings.cosine_similarity(vector1, vector2)\n",
    "\n",
    "    def generateIndexFlatL2(self):\n",
    "        dim, measure = 768, faiss.METRIC_L2\n",
    "        param = 'Flat'\n",
    "        vectors = np.array([vector for vector in self.vectors]).astype('float32')\n",
    "        index = faiss.index_factory(dim, param, measure)\n",
    "        index.add(vectors)  # 将向量库中的向量加入到index中\n",
    "        return index\n",
    "\n",
    "\n",
    "    def query(self, query: str, EmbeddingModel: BaseEmbeddings, k: int = 1) -> List[str]:\n",
    "\n",
    "        start_time = time.time()\n",
    "        index = self.generateIndexFlatL2()\n",
    "        end_time = time.time()\n",
    "        print(' 构建索引 cost %f second' % (end_time - start_time))\n",
    "        query_vector = np.array([EmbeddingModel.get_embedding(query)]).astype(\"float32\")\n",
    "\n",
    "        end_time = time.time()\n",
    "        D, I = index.search(query_vector, k)  # xq为待检索向量，返回的I为每个待检索query最相似TopK的索引list，D为其对应的距离\n",
    "        print(' 检索 cost %f second' % (time.time() - end_time))\n",
    "        return np.array(self.document)[I[0,:]].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector = FaissVectorStore()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "vector.load_vector('../../storage/history_24_1') # 加载本地的数据库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "768"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(vector.vectors[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 四、定义embedding算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from RAG.PaddleEmbedding import PaddleEmbedding\n",
    "embedding = PaddleEmbedding() # 创建EmbeddingModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "question = '介绍一下秦二世'\n",
    "content = vector.query(question, EmbeddingModel=embedding, k=3)[0]\n",
    "print(f'知识库输出：{content}')\n",
    "chat = QwenChat()\n",
    "print(chat.chat(question, [], content))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
